﻿using jvm.classfile;
using jvm.classpath;
using jvm.rtda.frame;
using jvm.rtda.heap.clazz.clazz;
using jvm.rtda.heap.pool;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;

namespace jvm.rtda.heap.clazz
{
    /// <summary>
    /// Loads classes from a <see cref="Classpath"/>, links them, and caches them in
    /// <see cref="classMap"/> keyed by internal name (e.g. "java/lang/Object").
    /// After construction, every class this loader returns has a java/lang/Class
    /// instance in <c>jLClassObject</c> whose <c>extClass</c> points back to it.
    /// </summary>
    class JClassLoader
    {
        const string JavaLangClassName = "java/lang/Class";

        public Classpath classpath;
        public Dictionary<string, JClass> classMap = new Dictionary<string, JClass>();
        readonly bool log;

        /// <summary>
        /// Creates a loader and eagerly loads java/lang/Class (together with the
        /// supertypes it pulls in) and the primitive classes.
        /// </summary>
        /// <param name="classpath">Source the class file bytes are read from.</param>
        /// <param name="log">When true, prints one line per class loaded from the classpath.</param>
        public JClassLoader(Classpath classpath, bool log)
        {
            this.classpath = classpath;
            this.log = log;
            LoadBasicClasses();
            LoadPrimitiveClasses();
        }

        /// <summary>
        /// Loads java/lang/Class, then gives every class loaded so far its Class object.
        /// </summary>
        void LoadBasicClasses()
        {
            JClass javaLangClass = LoadClass(JavaLangClassName);
            // LoadClass only attaches a Class object once java/lang/Class is already in the
            // map, so java/lang/Class itself and everything loaded while resolving it were
            // cached without one. Backfill them here.
            foreach (JClass jClass in classMap.Values)
            {
                AttachClassObject(jClass, javaLangClass);
            }
        }

        /// <summary>
        /// Registers a synthetic class for every key of <c>ClassNameHelper.map</c>
        /// (presumably the primitive type names — confirm against ClassNameHelper).
        /// </summary>
        void LoadPrimitiveClasses()
        {
            foreach (string primitiveClass in ClassNameHelper.map.Keys)
            {
                LoadPrimitiveClass(primitiveClass);
            }
        }

        /// <summary>
        /// Creates a public, already-initialized class named <paramref name="classname"/>
        /// without reading any class file, attaches its Class object, and caches it.
        /// Requires java/lang/Class to have been loaded already.
        /// </summary>
        void LoadPrimitiveClass(string classname)
        {
            JClass jClass = new JClass
            {
                accessFlags = AccessFlags.ACC_PUBLIC,
                name = classname,
                loader = this,
                inited = true
            };
            AttachClassObject(jClass, classMap[JavaLangClassName]);
            classMap[classname] = jClass;
        }

        /// <summary>
        /// Creates a new instance of <paramref name="javaLangClass"/> and links it with
        /// <paramref name="jClass"/> in both directions.
        /// </summary>
        static void AttachClassObject(JClass jClass, JClass javaLangClass)
        {
            JObject classObject = javaLangClass.NewObject();
            jClass.jLClassObject = classObject;
            classObject.extClass = jClass;
        }

        /// <summary>
        /// Returns the class with the given internal name, loading and caching it on first use.
        /// Names starting with '[' are array classes and are synthesized rather than read.
        /// </summary>
        /// <param name="classname">Internal class name, e.g. "java/lang/String" or "[I".</param>
        /// <exception cref="ArgumentNullException"><paramref name="classname"/> is null.</exception>
        /// <exception cref="ArgumentException"><paramref name="classname"/> is empty.</exception>
        public JClass LoadClass(string classname)
        {
            if (classname == null)
            {
                throw new ArgumentNullException("classname");
            }
            if (classname.Length == 0)
            {
                throw new ArgumentException("Class name must not be empty.", "classname");
            }

            JClass jClass;
            if (classMap.TryGetValue(classname, out jClass))
            {
                return jClass;
            }

            jClass = classname[0] == '[' ? LoadArray(classname) : Load(classname);

            // During bootstrap (before java/lang/Class is cached) no Class object can be
            // created yet; LoadBasicClasses backfills those classes afterwards.
            JClass javaLangClass;
            if (classMap.TryGetValue(JavaLangClassName, out javaLangClass))
            {
                AttachClassObject(jClass, javaLangClass);
            }
            classMap[classname] = jClass;
            return jClass;
        }

        /// <summary>
        /// Synthesizes an array class: public, already initialized, extending java/lang/Object
        /// and implementing java/lang/Cloneable and java/io/Serializable.
        /// </summary>
        JClass LoadArray(string classname)
        {
            JClass jClass = new JClass
            {
                accessFlags = AccessFlags.ACC_PUBLIC,
                name = classname,
                loader = this,
                inited = true,
                superClass = LoadClass("java/lang/Object"),
                interfaces = new JClass[] {
                    LoadClass("java/lang/Cloneable"),
                    LoadClass("java/io/Serializable")
                }
            };
            return jClass;
        }

        /// <summary>
        /// Reads, defines and links a non-array class from the classpath.
        /// </summary>
        JClass Load(string classname)
        {
            Tuple<byte[], Entry> classData = ReadClass(classname);
            byte[] data = classData.Item1;
            Entry entry = classData.Item2;
            JClass jClass = DefineClass(data);
            Link(jClass);
            if (log)
            {
                Console.WriteLine("[Loaded {0} from {1}]", classname, entry);
            }
            return jClass;
        }

        /// <summary>
        /// Reads the class file bytes and the classpath entry they came from.
        /// Terminates the process if the class cannot be found.
        /// </summary>
        Tuple<byte[], Entry> ReadClass(string classname)
        {
            Tuple<bool, byte[], Entry> result = classpath.ReadClass(classname);
            if (!result.Item1)
            {
                // A missing class is a fatal error: report it on stderr and exit with a
                // non-zero status so callers can tell the run failed.
                Console.Error.WriteLine("java.lang.ClassNotFoundException: {0}", classname);
                Environment.Exit(1);
            }
            return Tuple.Create(result.Item2, result.Item3);
        }

        /// <summary>
        /// Parses class file bytes into a <see cref="JClass"/> owned by this loader and
        /// resolves (loads) its superclass and interfaces.
        /// </summary>
        JClass DefineClass(byte[] data)
        {
            JClass jClass = ParseClass(data);
            jClass.loader = this;
            ResolveSuperClass(jClass);
            ResolveInterfaces(jClass);
            return jClass;
        }

        /// <summary>Parses raw class file bytes into a new <see cref="JClass"/>.</summary>
        JClass ParseClass(byte[] data)
        {
            ClassFile classFile = new ClassFile();
            classFile.Parse(data);
            return new JClass(classFile);
        }

        /// <summary>
        /// Loads the superclass of <paramref name="jClass"/>; java/lang/Object is the root and has none.
        /// </summary>
        void ResolveSuperClass(JClass jClass)
        {
            if (jClass.name != "java/lang/Object")
            {
                jClass.superClass = LoadClass(jClass.superClassName);
            }
        }

        /// <summary>Loads every interface named by <paramref name="jClass"/>, preserving order.</summary>
        void ResolveInterfaces(JClass jClass)
        {
            jClass.interfaces = new JClass[jClass.interfaceNames.Length];
            for (int i = 0; i < jClass.interfaces.Length; i++)
            {
                jClass.interfaces[i] = LoadClass(jClass.interfaceNames[i]);
            }
        }

        /// <summary>Links a freshly defined class: verification, then preparation.</summary>
        void Link(JClass jClass)
        {
            Verify(jClass);
            Prepare(jClass);
        }

        /// <summary>Placeholder: no class verification is performed.</summary>
        void Verify(JClass jClass)
        {
        }

        /// <summary>
        /// Assigns slot ids to instance and static fields and allocates/initializes static storage.
        /// </summary>
        void Prepare(JClass jClass)
        {
            CalcInstanceFieldSlotIds(jClass);
            CalcStaticFieldSlotIds(jClass);
            AllocAndInitStaticVars(jClass);
        }

        /// <summary>
        /// Numbers the non-static fields of <paramref name="jClass"/>, continuing after the
        /// superclass's instance slots so inherited fields keep their positions.
        /// </summary>
        void CalcInstanceFieldSlotIds(JClass jClass)
        {
            uint firstSlotId = jClass.superClass != null ? jClass.superClass.instanceSlotCount : 0u;
            jClass.instanceSlotCount = AssignFieldSlotIds(jClass, false, firstSlotId);
        }

        /// <summary>Numbers the static fields of <paramref name="jClass"/> starting at slot 0.</summary>
        void CalcStaticFieldSlotIds(JClass jClass)
        {
            jClass.staticSlotCount = AssignFieldSlotIds(jClass, true, 0u);
        }

        /// <summary>
        /// Gives each field of <paramref name="jClass"/> whose static-ness matches
        /// <paramref name="isStatic"/> consecutive slot ids, starting at <paramref name="firstSlotId"/>.
        /// </summary>
        /// <returns>The next free slot id, i.e. the total slot count.</returns>
        static uint AssignFieldSlotIds(JClass jClass, bool isStatic, uint firstSlotId)
        {
            uint slotId = firstSlotId;
            foreach (JField field in jClass.fields)
            {
                if (field.Is(AccessFlags.ACC_STATIC) != isStatic)
                {
                    continue;
                }
                field.slotId = slotId;
                // long and double values occupy two consecutive slots.
                slotId += field.IsLongOrDouble() ? 2u : 1u;
            }
            return slotId;
        }

        /// <summary>
        /// Allocates the static slot storage and writes the compile-time constant value of
        /// every static final field.
        /// </summary>
        void AllocAndInitStaticVars(JClass jClass)
        {
            jClass.staticVars = new Slots(jClass.staticSlotCount);
            foreach (JField field in jClass.fields)
            {
                if (field.Is(AccessFlags.ACC_STATIC) && field.Is(AccessFlags.ACC_FINAL))
                {
                    InitStaticFinalVar(jClass, field);
                }
            }
        }

        /// <summary>
        /// Copies the constant-pool value referenced by <c>field.constValueIndex</c> into the
        /// field's static slot. An index of 0 means there is no constant to apply.
        /// String constants are stored as string objects obtained from <c>StringPool</c>.
        /// </summary>
        void InitStaticFinalVar(JClass jClass, JField field)
        {
            uint valueIndex = field.constValueIndex;
            if (valueIndex == 0)
            {
                return;
            }

            Slots slots = jClass.staticVars;
            JConstantPool pool = jClass.constantPool;
            // Field descriptor codes, JVMS §4.3.2. boolean/byte/char/short are stored as int.
            switch (field.descriptor)
            {
                case "Z":
                case "B":
                case "C":
                case "S":
                case "I":
                    slots.SetInt(field.slotId, pool.GetConstant<int>(valueIndex));
                    break;
                case "J":
                    slots.SetLong(field.slotId, pool.GetConstant<long>(valueIndex));
                    break;
                case "F":
                    slots.SetFloat(field.slotId, pool.GetConstant<float>(valueIndex));
                    break;
                case "D":
                    slots.SetDouble(field.slotId, pool.GetConstant<double>(valueIndex));
                    break;
                case "Ljava/lang/String;":
                    string str = pool.GetConstant<string>(valueIndex);
                    slots.SetRef(field.slotId, StringPool.StringObject(jClass.loader, str));
                    break;
            }
        }
    }
}
